import torch
from model import EIModel
from dataSet import make_data1
import matplotlib.pyplot as plt

# Load the dataset (presumably X = education, Y = income, matching the axis labels below).
X, Y = make_data1()

# Load the trained model weights.
model = EIModel()
# map_location='cpu' lets a checkpoint saved from a GPU run load on a CPU-only machine.
model_state_dict = torch.load(r'model.pth', map_location='cpu')
model.load_state_dict(model_state_dict)
# Inference mode: makes layers such as dropout/batch-norm (if EIModel has any) deterministic.
model.eval()

# Model validation: extract and report the learned parameters of the linear layer.
weight, bias = model.linear.weight, model.linear.bias
print(f"weight: {weight.detach().tolist()}, bias: {bias.detach().tolist()}")

# Show the predictions against the raw data.
plt.scatter(X, Y)
plt.xlabel('Education')
plt.ylabel('Income')

# No autograd graph is needed just to evaluate predictions for plotting.
with torch.no_grad():
    predict = model(X).detach().numpy()

plt.plot(X, predict, c='red')
# Save before show(): some backends clear the figure once the window is closed.
plt.savefig('result.jpg')
plt.show()
